{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate and Save Sampling(IME)/Kernel Explainer convergence details for a simulated data set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import xgboost\n",
    "import numpy as np\n",
    "import shap\n",
    "import time\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pylab as pl\n",
    "from shap import KernelExplainer\n",
    "from shap import SamplingExplainer\n",
    "N = 1000\n",
    "Ms = [20,30,40,50,60,70,80,90]\n",
    "\n",
    "np.random.seed(3011)\n",
    "ime_std_lst =[]\n",
    "ime_m_lst = []\n",
    "kernel_shap_std_lst =[]\n",
    "kernel_shap_m_lst = []\n",
    "for l in range(0,30):\n",
    "    X_full = np.random.randn(N, 100)\n",
    "    y = np.random.randn(N)\n",
    "    kernel_shap_nsamples = []\n",
    "    ime_nsamples = []\n",
    "    tree_shap_std = []\n",
    "    kernel_shap_std = []\n",
    "    kernel_shap_m = []\n",
    "    ime_std = []\n",
    "    ime_m = []\n",
    "    for M in Ms:\n",
    "        print(\"\\nM\", M)\n",
    "        X = X_full[:,:M]\n",
    "        model = xgboost.train({\"silent\":1}, xgboost.DMatrix(X, y), 1000)\n",
    "\n",
    "        def f(x):\n",
    "            return model.predict(xgboost.DMatrix(x))\n",
    "\n",
    "        e = shap.TreeExplainer(model)\n",
    "        e.shap_values(X)\n",
    "        tree_shap_std.append(0)\n",
    "\n",
    "        print(\"Kernel Explainer\")    \n",
    "        nsamples = 1000 * M\n",
    "        print(nsamples)\n",
    "        e = KernelExplainer(f, X.mean(0).reshape(1,M))\n",
    "        out = np.vstack([e.shap_values(X[:1,:], silent=True, nsamples=nsamples) for i in range(10)])\n",
    "        std_dev = out.std(0)[:-1].mean()\n",
    "        mval = np.abs(out.mean(0))[:-1].mean()\n",
    "        print(std_dev, mval, std_dev / mval)\n",
    "        kernel_shap_nsamples.append(nsamples)\n",
    "        kernel_shap_std.append(std_dev)\n",
    "        kernel_shap_m.append(mval)\n",
    "\n",
    "        nsamples = 1000 * M\n",
    "        print(\"Sample Explainer\")\n",
    "        print(nsamples)\n",
    "        e = SamplingExplainer(f, X.mean(0).reshape(1,M))\n",
    "        out = np.vstack([e.shap_values(X[:1,:], silent=True, nsamples=nsamples) for i in range(10)])\n",
    "        std_dev = out.std(0)[:-1].mean()\n",
    "        mval = np.abs(out.mean(0))[:-1].mean()\n",
    "        print(std_dev, mval, std_dev / mval)\n",
    "        ime_nsamples.append(nsamples)\n",
    "        ime_std.append(std_dev)\n",
    "        ime_m.append(mval)\n",
    "\n",
    "    ime_std_lst.append(ime_std)\n",
    "    ime_m_lst.append(ime_m)\n",
    "    kernel_shap_std_lst.append(kernel_shap_std)\n",
    "    kernel_shap_m_lst.append(kernel_shap_m)\n",
    "np.save(\"scratch/kernel_shap_std_lst\",np.array(kernel_shap_std_lst))\n",
    "np.save(\"scratch/ime_std_lst\",np.array(ime_std_lst))\n",
    "np.save(\"scratch/kernel_shap_m_lst\",np.array(kernel_shap_m_lst))\n",
    "np.save(\"scratch/ime_m_lst\",np.array(ime_m_lst))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
